Feature: Add support for L40 FusedMoE in cutlass path #1973
Conversation
Signed-off-by: Amir Klein <[email protected]>
Summary of Changes

Hello @amirkl94, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request introduces comprehensive support for L40 GPUs (SM89 architecture) within the CUTLASS FusedMoE kernels. It resolves critical compilation issues and a runtime crash related to shared memory allocation for a specific GEMM tactic on SM89, ensuring stable and efficient operation. The changes also include updates to the Python JIT compilation infrastructure to enable proper module generation and loading for L40.
Note: CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Add explicit SM89 support and adjust FP4/FP8 gating across CUTLASS heuristics, MOE GEMM dispatch, and Flashinfer JIT: new SM89 NVCC flags and JIT module generator, tightened ENABLE_FP4 guards, reordered FP8 GROUPED_GEMM tile choices for SM89, and a logging-level tweak in the autotuner.
Sequence Diagram(s)

sequenceDiagram
actor User
participant PyCore as flashinfer/fused_moe/core.py
participant JITGen as fused_moe.gen_cutlass_fused_moe_sm89_module
participant JITCore as flashinfer/jit/core.py
participant CppDisp as moe_gemm_template_dispatch.h
participant CUTLASS as cutlass_kernels
User->>PyCore: get_cutlass_fused_moe_module(backend="89")
PyCore->>JITGen: gen_cutlass_fused_moe_sm89_module(use_fast_build)
JITGen->>JITCore: request NVCC flags (sm89_nvcc_flags + BF16/FP8 [+FP8_BLOCK_SCALE?])
JITGen->>PyCore: build_and_load() -> compiled module
PyCore-->>User: return loaded module
Note over CppDisp,CUTLASS: MOE GEMM FP format dispatch
alt FP8 GROUPED_GEMM & sm == 89
CppDisp->>CUTLASS: select SM89-specific tile configs (new order)
else FP8 GROUPED_GEMM & sm >= 120
CppDisp->>CUTLASS: select SM>=120 tile configs
else
CppDisp->>CUTLASS: fallback/default tile configs
end
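To make the flow above concrete, here is a minimal Python sketch of how a caller could end up on the SM89 backend; the helper below and the commented-out call are illustrative assumptions based on this PR's walkthrough, not the actual flashinfer API.

```python
# Hypothetical sketch: derive the fused-MoE backend string from the detected
# compute capability. Only torch.cuda.get_device_capability is a real API here;
# get_cutlass_fused_moe_module is the module referenced in the walkthrough.
import torch

def pick_fused_moe_backend(device: int = 0) -> str:
    major, minor = torch.cuda.get_device_capability(device)
    return f"{major}{minor}"  # "89" on an L40 (Ada / SM89)

if torch.cuda.is_available():
    backend = pick_fused_moe_backend()
    print("fused MoE backend:", backend)
    # module = get_cutlass_fused_moe_module(backend=backend)  # "89" now routes to the SM89 JIT path
```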
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
Code Review
This pull request adds support for L40 GPUs (sm_89) in the CUTLASS FusedMoE path. The changes include fixing compilation issues, removing a problematic GEMM tactic for sm_89 that was causing crashes, and adding the necessary build configurations for this architecture. The changes are logical and well-implemented. I have one suggestion to improve code clarity when constructing compiler flags.
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_tma_warp_specialized_traits.h (1)
42-42: Fix misleading comment about SM120 vs SM100.

The comment mentions CUTLASS_ARCH_MMA_SM100_SUPPORTED, but this function checks for CUTLASS_ARCH_MMA_SM120_SUPPORTED (line 35). Update the comment to accurately reflect the macro being checked. Apply this diff:

- return false; // CUTLASS_ARCH_MMA_SM100_SUPPORTED is set when Blackwell kernels are enabled
+ return false; // CUTLASS_ARCH_MMA_SM120_SUPPORTED is set when Blackwell kernels are enabled
🧹 Nitpick comments (2)
flashinfer/jit/fused_moe.py (1)
80-88: SM89 module generation is correctly implemented.

The function appropriately:
- Uses sm89_nvcc_flags, which excludes FP4 support for L40
- Omits Hopper-specific TMA GEMM flags (correct for Ada architecture)
- Includes conditional FP8 block scale support for CUDA ≥ 12.8
Optional: Consider iterable unpacking for cleaner syntax.
As suggested by Ruff, you could use iterable unpacking instead of concatenation:
- nvcc_flags = sm89_nvcc_flags + [
+ nvcc_flags = [
+     *sm89_nvcc_flags,
      "-DENABLE_BF16",
      "-DENABLE_FP8",
      "-DENABLE_FP8_BLOCK_SCALE" if is_cuda_version_at_least("12.8") else "",
      "-DUSING_OSS_CUTLASS_MOE_GEMM",
  ]

This is a minor style improvement and can be deferred.
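For illustration, a self-contained sketch of the two equivalent constructions, with a placeholder flag list and a stubbed version check standing in for the real flashinfer definitions:

```python
# Placeholder stand-ins for flashinfer's sm89_nvcc_flags / is_cuda_version_at_least.
sm89_nvcc_flags = ["-gencode=arch=compute_89,code=sm_89"]  # assumed value, illustration only

def is_cuda_version_at_least(version: str) -> bool:
    return True  # stub

# Concatenation (current style in the PR)
flags_concat = sm89_nvcc_flags + [
    "-DENABLE_BF16",
    "-DENABLE_FP8",
    "-DENABLE_FP8_BLOCK_SCALE" if is_cuda_version_at_least("12.8") else "",
    "-DUSING_OSS_CUTLASS_MOE_GEMM",
]

# Iterable unpacking (RUF005 suggestion) -- produces the same list
flags_unpack = [
    *sm89_nvcc_flags,
    "-DENABLE_BF16",
    "-DENABLE_FP8",
    "-DENABLE_FP8_BLOCK_SCALE" if is_cuda_version_at_least("12.8") else "",
    "-DUSING_OSS_CUTLASS_MOE_GEMM",
]
assert flags_concat == flags_unpack
```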
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h (1)
696-709: LGTM: Correctly reorganizes SM89 dispatch to avoid shared memory issues.

The reorganized control flow properly addresses the L40 (SM89) issue by:
- Routing FP8 workloads to Sm89 kernels with runtime validation (line 703)
- Routing non-FP8 workloads to Sm80 kernels (lines 707-708)
This aligns with the kernel implementation in moe_cutlass_kernel.h, which shows the SM89 architecture reusing Sm80 kernels for non-FP8 types, and prevents the "GPU lacks the shared memory resources" assertion mentioned in the PR objectives.

Optional suggestion: Consider adding a brief comment explaining why non-FP8 on SM89 uses the Sm80 path, to help future maintainers understand the shared memory constraint that motivated this design.
Apply this diff to add a clarifying comment:
  } else {
+   // Non-FP8 workloads on SM89 (L40) reuse Sm80 kernels to avoid
+   // Sm89-specific tactics that exceed shared memory limits
    dispatchMoeGemmToCutlass<T, WeightType, ScaleBiasType, cutlass::arch::Sm80, EpilogueTag>(
        inputs, multi_processor_count_);
  }
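To summarize the SM 80-89 routing discussed in this review, here is a schematic Python rendering of the decision; the function name, dtype tags, and return strings are illustrative only and do not exist in the codebase.

```python
# Schematic of the SM 80-89 dispatch decision (Python rendering of the C++
# control flow in moe_gemm_template_dispatch.h; not a real API).
def pick_cutlass_arch(sm: int, dtype: str) -> str:
    assert 80 <= sm < 90
    if dtype in ("fp8", "w4afp8"):
        if sm != 89:
            raise RuntimeError("For sm >= 80 and < 90, fp8 is only supported with sm == 89")
        return "Sm89"   # Ada-specific FP8 grouped-GEMM kernels
    if dtype == "fp4":
        raise RuntimeError("FP4 MoE GEMM requires sm >= 90")
    return "Sm80"       # non-FP8 on SM89 reuses the Ampere kernels
```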
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
- csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp (1 hunks)
- csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h (1 hunks)
- csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_tma_warp_specialized_traits.h (2 hunks)
- flashinfer/fused_moe/core.py (2 hunks)
- flashinfer/jit/core.py (1 hunks)
- flashinfer/jit/fused_moe.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
flashinfer/fused_moe/core.py (2)
- flashinfer/jit/fused_moe.py (1): gen_cutlass_fused_moe_sm89_module (80-87)
- flashinfer/jit/core.py (1): build_and_load (272-284)

csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h (2)
- csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h (6): __nv_fp8_e5m2 (91-93), cutlass (114-116), cutlass (120-122), cutlass (127-129), cutlass (132-134), cutlass (140-142)
- csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h (1): cutlass (40-677)

flashinfer/jit/fused_moe.py (2)
- flashinfer/jit/core.py (2): JitSpec (185-284), gen_jit_spec (287-353)
- flashinfer/jit/cpp_ext.py (1): is_cuda_version_at_least (86-87)
🪛 Ruff (0.14.1)
flashinfer/jit/fused_moe.py
81-86: Consider iterable unpacking instead of concatenation
Replace with iterable unpacking
(RUF005)
🔇 Additional comments (6)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp (1)

161-168: LGTM! Targeted fix for L40 shared memory constraints.

The separation of SM89 from SM >= 120 handling correctly removes CtaShape16x256x128_WarpShape16x64x128 for L40 GPUs. This config's larger K dimension (128 vs 64) would exceed L40's shared memory capacity in GROUPED_GEMM mode, causing the assertion failure mentioned in the PR. The fix is minimal, well-scoped, and maintains full functionality for newer architectures.
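As a rough, editorial illustration of why that tile can exceed the budget, a back-of-the-envelope estimate follows; the stage count and the assumption that only the A/B operand tiles occupy shared memory are simplifications, not a description of the actual CUTLASS kernel.

```python
# Back-of-the-envelope shared-memory estimate for the removed tile config
# (illustrative only; real CUTLASS staging, padding and epilogue smem differ).
cta_m, cta_n, cta_k = 16, 256, 128   # CtaShape16x256x128
bytes_per_elem = 1                   # FP8 operands
stages = 3                           # assumed pipeline depth

per_stage = (cta_m * cta_k + cta_n * cta_k) * bytes_per_elem  # A tile + B tile
total = per_stage * stages
print(per_stage, total)  # 34816 bytes per stage, 104448 bytes total
# An SM89 part such as the L40 caps a thread block at roughly 99 KB (101376 B)
# of shared memory, so this estimate already overshoots; halving the K depth
# to 64 roughly halves the per-stage footprint and fits comfortably.
```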
flashinfer/jit/core.py (1)

93-96: SM89 flags correctly omit FP4 support.

The implementation appropriately excludes -DFLASHINFER_ENABLE_FP4_E2M1 for SM89 (L40), unlike other architectures that use common_nvcc_flags. This aligns with the PR objective of fixing compilation issues for L40 by removing problematic GEMM tactics.

flashinfer/fused_moe/core.py (2)

40-40: LGTM! Import added correctly to support SM89 module generation.

289-290: LGTM! Backend dispatch for SM89 follows the established pattern and correctly invokes the new module generation function.

flashinfer/jit/fused_moe.py (1)

21-27: LGTM! Import correctly updated to include sm89_nvcc_flags, needed for the new SM89 module generation function.

csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h (1)

692-694: LGTM: Proper FP4 validation for SM89.

The explicit compile-time check correctly prevents FP4 usage on SM < 90, consistent with the earlier validation for SM75-80 at lines 685-688.
Signed-off-by: Amir Klein <[email protected]>
Actionable comments posted: 1
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_tma_warp_specialized_traits.h (2 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (1)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_tma_warp_specialized_traits.h (1)
35-36: LGTM: Guard is appropriate for FP4-only SM120 specialization.

The ENABLE_FP4 guard is correct here. Since isValidSM120MOESpecialisation only supports __nv_fp4_e2m1 types (lines 37-40), requiring ENABLE_FP4 at the preprocessor level prevents compilation errors when FP4 support is not available.
Looks good to me, mainly just want to understand why we need to disable the tile shape.
Coderabbit's comment about more granular Fp4 guard might make sense, but I assume that if we are compiling with blackwell support we should also have FP4
CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64,
CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64,
CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64,
CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64};
Are there any SM89 GPUs that can support the CtaShape16x256x128_WarpShape16x64x128, or is this an SM120 addition?
Having a small M value like this helps with low latency cases, so I'd want to understand why it's not supported before disabling it
At the very least, can you leave a comment saying what the difference between the two lists is, so people don't have to manually compare the items?
> Are there any SM89 GPUs that can support the CtaShape16x256x128_WarpShape16x64x128, or is this an SM120 addition? Having a small M value like this helps with low latency cases, so I'd want to understand why it's not supported before disabling it
This is a removal from the sm89 path. When I tested it on an L40 GPU I got Assertion failed: GPU lacks the shared memory resources to run GroupedGEMM kernel.
It might be that on other sm89 GPUs it will pass; the main issue is that this was the default tactic chosen when trying to use FusedMoE. I believe that moving it to be the last will also fix my issue.
I tried moving this tile config to be the last one and now the default tactic won't fail on L40. The issue is that if the autotuner is on, the tactics that use this tile config will report an error with a stacktrace, which looks bad.
@yzh119 do you think it'll be ok to change the errors that happen in the autotuner to be debug logs? Otherwise it means users will get spammed with error messages when they run autotuning on L40 FusedMoE.
I think the autotuner should still output warnings, but just make them say "Skipping tactic x due to error. This tactic may not be supported by the current GPU architecture".
That said I know there is a difference of opinion on whether we should proactively filter them as you have done here, the argument being that we should be able to do the due diligence to determine what tactics are supported so that we can raise an error when a tactic fails when it shouldn't. So I can see either side.
I agree. I think that since this PR is somewhat time critical I'd rather move the faulty tile config to be the last so as not to remove it, and change the log to be a warning as you suggested and maybe a debug log with the stacktrace.
I'll also open an issue to take a deeper look into whether this specific tile config is relevant for sm89 or not, as I think it might take me some time.
@djns99 , @yzh119 Does that sound ok to you?
I'm happy with that resolution
There are some compilation errors on CI, @amirkl94 would you mind taking a look?
Signed-off-by: Amir Klein <[email protected]>
Actionable comments posted: 0
🧹 Nitpick comments (1)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h (1)
691-699: LGTM: FP8/W4AFP8 routing to SM89 is correct.

The restructured logic correctly routes FP8 and W4AFP8 workloads to the SM89 architecture for the SM 80-89 range, which is essential for L40 support. The runtime check at line 697 appropriately prevents FP8 usage on SM80-88 hardware that lacks FP8 support. The static_assert placement within the ENABLE_FP8 block is consistent with the codebase pattern (see lines 244-247).
Minor: Consider clarifying the error message.
The error message at line 697 mentions "fp8" but the condition includes use_w4afp8. For clarity, consider updating the message to reflect both supported types.

- TLLM_CHECK_WITH_INFO(sm_ == 89, "For sm >= 80 and < 90, fp8 is only supported with sm == 89");
+ TLLM_CHECK_WITH_INFO(sm_ == 89, "For sm >= 80 and < 90, fp8 and w4afp8 are only supported with sm == 89");
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h (3)
- csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h (6): __nv_fp8_e5m2 (91-93), cutlass (114-116), cutlass (120-122), cutlass (127-129), cutlass (132-134), cutlass (140-142)
- csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h (1): cutlass (40-677)
- csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h (1): multi_processor_count_ (341-341)
🔇 Additional comments (1)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h (1)
700-712: LGTM: Non-FP8 handling and FP4 guard are correct.

The else branch correctly handles non-FP8 cases for SM 80-89:
- FP4 types are appropriately blocked on SM < 90 (line 703)
- Non-FP8 workloads dispatch to SM80 architecture, which is an intentional kernel reuse strategy for SM89 as documented in the kernel implementation
The duplication of the dispatch call (lines 705-706 and 709-710) is necessary for conditional compilation when ENABLE_FP4 is not defined.
Pushed a commit to fix this
Signed-off-by: Amir Klein <[email protected]>
…ses on L40 Signed-off-by: Amir Klein <[email protected]>
Signed-off-by: Amir Klein <[email protected]>
Actionable comments posted: 1
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp(1 hunks)csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_tma_warp_specialized_traits.h(2 hunks)flashinfer/autotuner.py(1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp
🧰 Additional context used
🧬 Code graph analysis (1)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_tma_warp_specialized_traits.h (2)
- csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h (5): cutlass (114-116), cutlass (120-122), cutlass (127-129), cutlass (132-134), cutlass (140-142)
- csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h (1): cutlass (40-677)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (2)
flashinfer/autotuner.py (1)
485-492: LGTM! Appropriate logging level adjustment.

The change from error to warning/debug levels is sensible for profiling failures during autotuning. Since failed tactics are automatically skipped (time set to infinity) and the fallback mechanism ensures operation continues, treating these as errors would create unnecessary log noise. The warning level appropriately notifies users that a tactic was skipped, while debug level preserves diagnostic details for troubleshooting.
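For reference, the logging pattern being endorsed looks roughly like the following self-contained sketch; the runner.profile hook and tactic objects are hypothetical, and this is not the actual flashinfer autotuner code.

```python
import logging

logger = logging.getLogger("autotuner")

def profile_tactics(runner, tactics):
    """Pick the fastest tactic; failed tactics are skipped rather than fatal."""
    best, best_time = None, float("inf")
    for tactic in tactics:
        try:
            elapsed = runner.profile(tactic)  # hypothetical profiling hook
        except Exception:
            # Warning tells the user a tactic was skipped (e.g. unsupported on
            # this GPU architecture); debug keeps the stacktrace for diagnosis.
            logger.warning(
                "Skipping tactic %s due to error; it may not be supported by "
                "the current GPU architecture.", tactic)
            logger.debug("Failure details for tactic %s", tactic, exc_info=True)
            elapsed = float("inf")
        if elapsed < best_time:
            best, best_time = tactic, elapsed
    return best
```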
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_tma_warp_specialized_traits.h (1)
35-36: LGTM! Guard correctly requires ENABLE_FP4.

The guard appropriately requires both CUTLASS_ARCH_MMA_SM120_SUPPORTED and ENABLE_FP4, since this specialization exclusively validates FP4 types (__nv_fp4_e2m1), which are only defined when FP4 support is compiled in.
Signed-off-by: Amir Klein <[email protected]>
LGTM, should be ready to merge once CI passes.

/bot run

LGTM

[SUCCESS] Pipeline #37466918: 13/17 passed
📌 Description
Fixed a few compilation issues for L40, and removed 1 gemm tactic for sm == 89 that crashes due to: Assertion failed: GPU lacks the shared memory resources to run GroupedGEMM kernel.

🧪 Tests

Ran pytest tests/moe/test_trtllm_cutlass_fused_moe.py manually on an L40 GPU and verified all tests passed.
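A programmatic equivalent of that manual run, assuming a CUDA-enabled flashinfer build, an SM89 GPU, and the same test path:

```python
# Run the fused-MoE test file the same way the manual verification did.
import pytest

exit_code = pytest.main(["tests/moe/test_trtllm_cutlass_fused_moe.py", "-v"])
raise SystemExit(exit_code)
```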
Summary by CodeRabbit
- New Features
- Bug Fixes / Compatibility
- Performance
- Chores